#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# vim: set ft=python:
# Copyright (c) 2016 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module should be placed in /etc/ansible/facts.d/ceph.facts on the remote
host.

It must be a Python 2 script and, at the same time, a Jinja2 template that is
rendered with the following context:

{
    "cluster": "cluster"
}
"""


from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import functools
import json
import logging
import os
import re
import subprocess
import sys
import time


# This file is a Jinja2 template: the placeholder in the string below is
# substituted with the "cluster" value of the template context at render time.
CLUSTER_NAME = "{{ cluster }}"
"""Name of the cluster."""

# Matches a single line of `ceph-disk list` output of the form
#   <data_device> ceph data, <word>, cluster <cluster>, <osd_name>,
#   journal <journal_device>
# and captures the data device, cluster name, OSD name and journal device
# as named groups. Lines of any other shape do not match.
CEPH_DISK_LIST_LINE = re.compile(
    r"^\s*(?P<data_device>\S+)\s+ceph\sdata,"
    r"\s*\w+,"
    r"\s*cluster\s*(?P<cluster>\w+),"
    r"\s*(?P<osd_name>\S+),"
    r"\s*journal\s*(?P<journal_device>\S+)\s*$",
    re.UNICODE
)

# Passed to wait_for_tmo() as the time budget for a command to finish.
PROCESS_WAIT_TIMEOUT = 10
"""Timeout of waiting for process output."""

# Passed to wait_for_tmo() as the grace period after SIGTERM, before SIGKILL.
PROCESS_GENTLE_KILL_TIMEOUT = 3
"""Timeout of waiting for process to be killed."""


if sys.version_info >= (3,):
    def unicode(item):
        """Python 3 stand-in for the Python 2 ``unicode`` builtin.

        Byte strings are decoded as UTF-8; any other object (for example an
        exception instance) is converted with ``str()``. The previous shim
        called ``.decode`` unconditionally, which raised AttributeError for
        exception objects.
        """
        if isinstance(item, bytes):
            return item.decode("utf-8")
        return str(item)


def handle_exceptions(func):
    """Wrap ``func`` so that it yields a process exit code.

    The wrapped callable runs ``func`` and discards its return value. If
    ``func`` completes, ``os.EX_OK`` is returned. If it raises ``Exception``,
    the error is logged with its traceback, a JSON document with empty
    ``osd_tree`` / ``osd_partitions`` and the error text is printed, and
    ``os.EX_SOFTWARE`` is returned.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except Exception as exc:
            logging.exception("Cannot collect facts: %s", exc)
            failure_report = {
                "osd_tree": {},
                "osd_partitions": {},
                "error": unicode(exc)
            }
            print_json(failure_report)
            return os.EX_SOFTWARE
        else:
            return os.EX_OK

    return wrapper


@handle_exceptions
def main():
    """Collect the OSD tree and OSD partitions and print them as JSON.

    The printed document has the keys ``osd_tree``, ``osd_partitions`` and
    ``error`` (``None`` here). Logging is configured at DEBUG level first.
    """
    logging.basicConfig(level=logging.DEBUG)

    # Keep this order: the tree is queried before the partitions.
    osd_tree = get_osd_tree()
    osd_partitions = get_osd_partitions()

    print_json(
        {
            "osd_tree": osd_tree,
            "osd_partitions": osd_partitions,
            "error": None
        }
    )


def get_osd_tree():
    """Return a mapping of host name to the list of its child nodes.

    Runs ``ceph osd tree`` for CLUSTER_NAME with JSON output, indexes the
    entries of its ``nodes`` list by ``id`` and, for every node whose
    ``type`` is ``host``, resolves the ids listed in ``children`` to the
    corresponding node dictionaries.
    """
    command = ["ceph", "--cluster", CLUSTER_NAME, "osd", "tree", "-f", "json"]
    tree = get_json_output(command, wait_stderr=False)

    nodes_by_id = {}
    for node in tree["nodes"]:
        nodes_by_id[node["id"]] = node

    hosts = {}
    for node in nodes_by_id.values():
        if node["type"] != "host":
            continue
        hosts[node["name"]] = [
            nodes_by_id[child_id] for child_id in node["children"]
        ]

    return hosts


def get_osd_partitions():
    """Return OSD data/journal devices parsed from ``ceph-disk list``.

    The result maps cluster name -> OSD name -> a dictionary with the keys
    ``journal`` and ``data``. Blank lines and lines that do not match
    CEPH_DISK_LIST_LINE are ignored.
    """
    listing = get_plain_output(["ceph-disk", "list"])
    partitions = {}

    for raw_line in listing.split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        found = CEPH_DISK_LIST_LINE.match(line)
        if not found:
            continue

        cluster_osds = partitions.setdefault(found.group("cluster"), {})
        cluster_osds[found.group("osd_name")] = {
            "journal": found.group("journal_device"),
            "data": found.group("data_device")
        }

    return partitions


def get_json_output(command, wait_stderr=False):
    """Run ``command`` and return its output parsed as JSON.

    ``command`` and ``wait_stderr`` are forwarded to get_plain_output().
    Raises ValueError if the output is empty after stripping whitespace.
    """
    text = get_plain_output(command, wait_stderr).strip()
    if text:
        return json.loads(text)

    logging.error("Output of %s is empty", command)
    raise ValueError("Output of {0} is empty".format(command))


def get_plain_output(command, wait_stderr=False):
    """Run ``command`` and return its captured output.

    The command is started with stdin attached to the null device and with
    stdout and stderr captured through pipes. Both streams are logged. A
    non-zero exit code is logged as an error but does not raise.

    Returns the captured stderr if ``wait_stderr`` is true, otherwise the
    captured stdout.
    """
    with open(os.devnull) as null_input:
        child = subprocess.Popen(
            command,
            stdin=null_input,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        captured_out, captured_err = wait_for_process(command, child)

    logging.debug("Stdout of %s: %s", command, captured_out)
    logging.warning("Stderr of %s: %s", command, captured_err)

    exit_code = child.returncode
    if exit_code != os.EX_OK:
        logging.error("Command %s has been finished with exit code %d",
                      command, exit_code)

    if wait_stderr:
        return captured_err
    return captured_out


def wait_for_process(command, process):
    """Wait for ``process`` to finish and return what wait_for_tmo() returns.

    If the process does not finish within PROCESS_WAIT_TIMEOUT it is sent
    SIGTERM; if it is still running after PROCESS_GENTLE_KILL_TIMEOUT it is
    sent SIGKILL. In both of these cases ValueError is raised, so a result
    is returned only when the process finished within the first timeout.

    ``command`` is used for log and error messages only.
    """
    logging.debug("Run %s. Pid %s", command, process.pid)

    try:
        result = wait_for_tmo(process, PROCESS_WAIT_TIMEOUT)
    except ValueError:
        logging.warning("Timeout for %s (pid %s) is expired. Send SIGTERM.",
                        command, process.pid)
    else:
        return result

    # First timeout expired: ask politely, then force.
    process.terminate()
    try:
        wait_for_tmo(process, PROCESS_GENTLE_KILL_TIMEOUT)
    except ValueError:
        logging.error(
            "Process %s (pid %s) is still running after SIGTERM. "
            "Send SIGKILL.",
            command, process.pid)
        process.kill()

    raise ValueError("Command {0} hangs.".format(command))


def wait_for_tmo(process, timeout):
    """Wait up to ``timeout`` seconds for ``process`` and return its output.

    Returns the ``(stdout, stderr)`` tuple produced by
    ``process.communicate()``. Raises ValueError("Still running.") if the
    process has not finished when the timeout expires.

    The output pipes are drained while waiting. Polling ``process.poll()``
    without reading the pipes deadlocks as soon as the child writes more
    than the OS pipe buffer can hold: the child blocks on write, never
    exits, and is then misreported as hung.
    """
    # Function-scope import keeps this fix local to the block.
    import threading

    # The drain state is cached on the process object so that repeated calls
    # for the same process (for instance a second wait after SIGTERM) reuse
    # the single reader thread instead of starting competing communicate()
    # calls on the same pipes.
    state = getattr(process, "_tmo_drain_state", None)
    if state is None:
        state = dict(output=None, error=None, thread=None)

        def drain():
            try:
                state["output"] = process.communicate()
            except Exception as exc:
                # Hand the failure over to the waiting thread.
                state["error"] = exc

        reader = threading.Thread(target=drain)
        # Daemon thread: a stuck child must not keep the interpreter alive.
        reader.daemon = True
        state["thread"] = reader
        process._tmo_drain_state = state
        reader.start()

    state["thread"].join(timeout)
    if state["thread"].is_alive():
        raise ValueError("Still running.")

    if state["error"] is not None:
        raise state["error"]
    return state["output"]


def print_json(data):
    """Print ``data`` to stdout as JSON, indented by 4 with sorted keys."""
    rendered = json.dumps(data, sort_keys=True, indent=4)
    print(rendered)


if __name__ == "__main__":
    # The value returned by main() becomes the process exit status.
    raise SystemExit(main())
